{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "6db2e2b4",
      "metadata": {
        "id": "6db2e2b4"
      },
      "source": [
        "## Build and Compose Conditioners\n",
        "\n",
        "### Overview\n",
        "Protein design via Chroma is highly customizable and programmable. Our robust Conditioner framework enables automatic conditional sampling tailored to a diverse array of protein specifications. This can involve either restraints (which bias the distribution of states using classifier guidance) or constraints (that directly limit the scope of the underlying sampling process). For a detailed explanation, refer to Supplementary Appendix M in our paper. We offer a variety of pre-defined conditioners, including those for managing substructure, symmetry, shape, semantics, and even natural-language prompts (see `chroma.layers.structure.conditioners`). These conditioners can be utilized in any combination to suit your specific needs."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "3b4c35a7",
      "metadata": {
        "id": "3b4c35a7"
      },
      "source": [
        "### Composing Conditioners\n",
        "\n",
        "Conditioners in Chroma can be combined seamlessly using `conditioners.ComposedConditioner`, akin to how layers are sequenced in `torch.nn.Sequential`. You can define individual conditioners and then aggregate them into a single collective list which will sequentially apply constrained transformations.\n",
        "\n",
        "```python\n",
        "composed_conditioner = conditioners.ComposedConditioner([conditioner1, conditioner2, conditioner3])\n",
        "```"
      ]
    },
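    {
      "cell_type": "markdown",
      "id": "a7c1e5f0",
      "metadata": {
        "id": "a7c1e5f0"
      },
      "source": [
        "The sequential behavior described above can be sketched without Chroma. This is a minimal illustration under stated assumptions, not Chroma's implementation: `ToySequentialConditioner`, `add_restraint`, and `mirror_constraint` are hypothetical names, and plain Python lists stand in for tensors. Only the `(X, C, O, U, t)` call signature is taken from the conditioner code later in this notebook.\n",
        "\n",
        "```python\n",
        "class ToySequentialConditioner:\n",
        "    # Hypothetical stand-in for a composed conditioner, NOT Chroma's implementation.\n",
        "    def __init__(self, conditioners):\n",
        "        self.conditioners = list(conditioners)\n",
        "\n",
        "    def __call__(self, X, C, O, U, t):\n",
        "        # Feed the output state of each conditioner into the next one, in list order.\n",
        "        for conditioner in self.conditioners:\n",
        "            X, C, O, U, t = conditioner(X, C, O, U, t)\n",
        "        return X, C, O, U, t\n",
        "\n",
        "\n",
        "def add_restraint(X, C, O, U, t):\n",
        "    # Toy restraint: leaves the coordinates alone and adds a penalty to the energy U.\n",
        "    return X, C, O, U + sum(x * x for x in X), t\n",
        "\n",
        "\n",
        "def mirror_constraint(X, C, O, U, t):\n",
        "    # Toy constraint: transforms the state directly by appending mirrored copies.\n",
        "    return X + [-x for x in X], C + C, O, U, t\n",
        "\n",
        "\n",
        "state = ([1.0, 2.0], [1, 1], None, 0.0, 0.5)\n",
        "X, C, O, U, t = ToySequentialConditioner([add_restraint, mirror_constraint])(*state)\n",
        "_, _, _, U_swapped, _ = ToySequentialConditioner([mirror_constraint, add_restraint])(*state)\n",
        "print(X, C, U, U_swapped)\n",
        "```\n",
        "\n",
        "Swapping the two toy conditioners changes the result: with the constraint first, the restraint sees all four coordinates and `U` becomes 10.0 instead of 5.0. This is one way to see why the order of the list matters."
      ]
    },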
    {
      "cell_type": "markdown",
      "source": [
        "#### Setup"
      ],
      "metadata": {
        "id": "b2lOsBQFhypc"
      },
      "id": "b2lOsBQFhypc"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "4db56efd",
      "metadata": {
        "id": "4db56efd"
      },
      "outputs": [],
      "source": [
        "import locale\n",
        "from google.colab import output\n",
        "output.enable_custom_widget_manager()\n",
        "locale.getpreferredencoding = lambda: \"UTF-8\"\n",
        "!pip install generate-chroma > /dev/null 2>&1\n",
        "from chroma import api\n",
        "api.register_key(input(\"Enter API key: \"))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "f3ee7c51",
      "metadata": {
        "id": "f3ee7c51"
      },
      "source": [
        "#### Example 1: Combining Symmetry and Secondary Structure\n",
        "In this scenario, we initially apply guidance for secondary structure to condition the content accordingly. This is followed by incorporating Cyclic symmetry. This approach involves adding a secondary structure classifier to conditionally sample an Asymmetric unit (AU) that is beta-rich, followed by symmetrization."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "37b9c48f",
      "metadata": {
        "id": "37b9c48f"
      },
      "outputs": [],
      "source": [
        "from chroma.models import Chroma\n",
        "from chroma.layers.structure import conditioners\n",
        "\n",
        "chroma = Chroma()\n",
        "# Conditional on C=2 (mostly beta)\n",
        "beta = conditioners.ProClassConditioner('cath', \"2\", weight=5, max_norm=20)\n",
        "c_symmetry = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=2)\n",
        "composed_cond = conditioners.ComposedConditioner([beta, c_symmetry])\n",
        "\n",
        "symm_beta = chroma.sample(chain_lengths=[100],\n",
        "    conditioner=composed_cond,\n",
        "    langevin_factor=8,\n",
        "    inverse_temperature=8,\n",
        "    sde_func=\"langevin\",\n",
        "    steps=500)\n",
        "\n",
        "symm_beta"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8b97d729",
      "metadata": {
        "id": "8b97d729"
      },
      "source": [
        "#### Example 2: Merging Symmetry and Substructure\n",
        "Here, our goal is to construct symmetric assemblies from a single-chain protein, partially redesigning it to merge three identical AUs into a Cyclic complex. We begin by defining the backbones targeted for redesign and then reposition the AU to prevent clashes during symmetrization. This is followed by the symmetrization operation itself.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "35c78aad",
      "metadata": {
        "id": "35c78aad"
      },
      "outputs": [],
      "source": [
        "from chroma.data import Protein\n",
        "\n",
        "PDB_ID = '3BDI'\n",
        "chroma = Chroma()\n",
        "\n",
        "protein = Protein(PDB_ID, canonicalize=True, device='cuda')\n",
        "# regenerate residues with X coord < 25 A and y coord < 25 A\n",
        "substruct_conditioner = conditioners.SubstructureConditioner(\n",
        "    protein, backbone_model=chroma.backbone_network, selection=\"x < 25 and y < 25\")\n",
        "\n",
        "# C_3 symmetry\n",
        "c_symmetry = conditioners.SymmetryConditioner(G=\"C_3\", num_chain_neighbors=3)\n",
        "\n",
        "# Composing\n",
        "composed_cond = conditioners.ComposedConditioner([substruct_conditioner, c_symmetry])\n",
        "\n",
        "protein, trajectories = chroma.sample(\n",
        "    protein_init=protein,\n",
        "    conditioner=composed_cond,\n",
        "    langevin_factor=4.0,\n",
        "    langevin_isothermal=True,\n",
        "    inverse_temperature=8.0,\n",
        "    sde_func='langevin',\n",
        "    steps=500,\n",
        "    full_output=True,\n",
        ")\n",
        "\n",
        "protein"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "de3c2b97",
      "metadata": {
        "id": "de3c2b97"
      },
      "source": [
        "### Build your own conditioners: 2D protein lattices\n",
        "\n",
        "An attractive aspect of this conditioner framework is that it is very general, enabling both constraints (which involve operations on $x$) and restraints (which amount to changes to $U$). At the same time, generation under restraints can still be (and often is) challenging, as the resulting effective energy landscape can become arbitrarily rugged and difficult to integrate. We therefore advise caution when using and developing new conditioners or conditioner combinations. We find that inspecting diffusition trajectories (including unconstrained and denoised trajectories, $\\hat{x}_t$ and $\\tilde{x}_t$) can be a good tool for identifying integration challenges and defining either better conditioner forms or better sampling regimes.\n",
        "\n",
        "Here we present how to build a conditioner that generates a periodic 2D lattice. You can easily extend this code snippet to generate 3D protein materials."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2bb9dcf3",
      "metadata": {
        "id": "2bb9dcf3"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "class Lattice2DConditioner(conditioners.Conditioner):\n",
        "    def __init__(self, M, N, cell):\n",
        "        super().__init__()\n",
        "        # Setup the coordinates of a 2D lattice\n",
        "        self.order = M*N\n",
        "        x = torch.arange(M) * cell[0]\n",
        "        y = torch.arange(N) * cell[1]\n",
        "        xx, yy = torch.meshgrid(x, y, indexing=\"ij\")\n",
        "        dX = torch.stack([xx.flatten(), yy.flatten(), torch.zeros(M * N)], dim=1)\n",
        "        self.register_buffer(\"dX\", dX)\n",
        "\n",
        "    def forward(self, X, C, O, U, t):\n",
        "        # Tesselate the unit cell on the lattice\n",
        "        X = (X[:,None,...] + self.dX[None,:,None,None]).reshape(1, -1, 4, 3)\n",
        "        C = torch.cat([C + C.unique().max() * i for i in range(self.dX.shape[0])], dim=1)\n",
        "        # Average the gradient\n",
        "        X.register_hook(lambda gradX: gradX / self.order)\n",
        "        return X, C, O, U, t\n",
        "\n",
        "chroma = Chroma()\n",
        "M, N = 3, 3\n",
        "conditioner = Lattice2DConditioner(M=M, N=N, cell=[25., 25.]).cuda()\n",
        "protein, trajectories = chroma.sample(\n",
        "    chain_lengths=[80], conditioner=conditioner, sde_func='langevin',\n",
        "    potts_symmetry_order=conditioner.order,\n",
        "    full_output=True\n",
        ")\n",
        "\n",
        "protein"
      ]
    },
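    {
      "cell_type": "markdown",
      "id": "d42b9e6c",
      "metadata": {
        "id": "d42b9e6c"
      },
      "source": [
        "The bookkeeping inside `Lattice2DConditioner` can be followed with plain Python lists. The sketch below is illustrative only: `lattice_offsets` and `tiled_chain_labels` are hypothetical helper names that are not part of Chroma, and they mirror just the construction of `dX` and the relabeling of `C`, not the coordinate tensors or the gradient hook.\n",
        "\n",
        "```python\n",
        "def lattice_offsets(M, N, cell):\n",
        "    # Row-major over (i, j), matching torch.meshgrid(..., indexing='ij') followed by flatten().\n",
        "    return [(i * cell[0], j * cell[1], 0.0) for i in range(M) for j in range(N)]\n",
        "\n",
        "\n",
        "def tiled_chain_labels(chain_labels, num_copies):\n",
        "    # Shift the labels of copy k by k * max(label) so that copies never share a chain ID.\n",
        "    step = max(chain_labels)\n",
        "    return [c + step * k for k in range(num_copies) for c in chain_labels]\n",
        "\n",
        "\n",
        "offsets = lattice_offsets(2, 2, cell=(25.0, 25.0))\n",
        "labels = tiled_chain_labels([1, 1, 1], num_copies=len(offsets))\n",
        "print(offsets)\n",
        "print(labels)\n",
        "```\n",
        "\n",
        "For a 2 x 2 lattice this yields four translation vectors in the z = 0 plane and chain labels 1 through 4, one label per copy of the three-residue toy chain."
      ]
    },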
    {
      "cell_type": "markdown",
      "id": "4ecf48ff",
      "metadata": {
        "id": "4ecf48ff"
      },
      "source": [
        "#### Notes on Troubleshooting\n",
        "1. The sequence in which you apply conditioners is crucial. Generally, it's best to apply stringent and all-encompassing constraints towards the end. For instance, symmetry, a constraint that affects the entire complex, should be implemented last in the conditioner list.\n",
        "When troubleshooting a conditioner, it's helpful to test it on a singular protein state. This helps in verifying if the resulting transformation aligns with your expectations.\n",
        "2. If your conditioner, like the SymmetryConditioner, make copies of a single protein multiple times, it's important to divide the pull-back gradients by the number of protein copies. This prevents excessive gradient accumulation on the protein asymmetric unit, similar to what occurs in the Lattice2DConditioner. Refer to Appendix M for more details.\n",
        "3. Adjusting sampling hyperparameters may be necessary when experimenting with new conditioners. Key parameters to consider include the langevin_factor, inverse_temperature, isothermal settings, steps, and guidance scale (especially when applying restraints). For dealing with hard constraints, it's usually advisable to use sde_func='langevin'."
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.18"
    },
    "colab": {
      "provenance": [],
      "toc_visible": true,
      "gpuType": "T4"
    },
    "accelerator": "GPU"
  },
  "nbformat": 4,
  "nbformat_minor": 5
}